"""
Phase 6: Cross-Asset Synthesis
===============================
Kopecky-Taylor "murder-suicide" summary, cross-asset coefficient table,
KAOPEN moderation, forward projections.
"""

import sys
from pathlib import Path

import numpy as np
import pandas as pd

# Root of this sub-project; every data/output path below hangs off it.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows/asset_returns")
# Sibling project directory whose `src` folder provides the `model` module.
MULTILATERAL_DIR = PROJECT_DIR.parent / "multilateral"
# Put that `src` folder first on the import path so `model` resolves to it.
sys.path.insert(0, str(MULTILATERAL_DIR / "src"))
from model import PanelGLS

# Input panel location (main() reads `asset_panel.csv` from here).
PROCESSED_DIR = PROJECT_DIR / "data" / "processed"
# Destination for the Markdown/CSV tables written by the build_* functions.
TABLES_DIR = PROJECT_DIR / "output" / "tables"
# Figure output location; not referenced in the visible part of this file.
FIGURES_DIR = PROJECT_DIR / "output" / "figures"

# Three-letter country codes used to subset the panel on its `iso3` column
# for the "OECD: ..." regressions (38 entries).
OECD_38 = [
    "AUS", "AUT", "BEL", "CAN", "CHL", "COL", "CRI", "CZE", "DNK", "EST",
    "FIN", "FRA", "DEU", "GRC", "HUN", "ISL", "IRL", "ISR", "ITA", "JPN",
    "KOR", "LVA", "LTU", "LUX", "MEX", "NLD", "NZL", "NOR", "POL", "PRT",
    "SVK", "SVN", "ESP", "SWE", "CHE", "TUR", "GBR", "USA",
]

# Countries reported in the projection table; this list also fixes the row
# order of that table (see the reindex in build_projections).
FOCUS_COUNTRIES = ['JPN', 'DEU', 'USA', 'KOR', 'CHN', 'IND', 'GBR', 'FRA',
                   'BRA', 'AUS', 'ITA', 'ESP', 'TUR', 'MEX', 'ZAF']


def stars(p):
    """Map a p-value to conventional significance stars.

    Args:
        p: p-value. May be ``None`` or NaN when no estimate is available.

    Returns:
        ``'***'`` for p < 0.01, ``'**'`` for p < 0.05, ``'*'`` for p < 0.10,
        and ``''`` otherwise (including for ``None`` and NaN).
    """
    # A missing p-value can arrive as None; comparing None with a float
    # raises TypeError, so treat it as "not significant" instead.
    if p is None:
        return ''
    # NaN compares False against every threshold and falls through to ''.
    for threshold, marker in ((0.01, '***'), (0.05, '**'), (0.10, '*')):
        if p < threshold:
            return marker
    return ''


def run_standardized(df, dep_var, regressors, label):
    """Fit a PanelGLS model on a z-scored dependent variable.

    Args:
        df: Panel with `iso3` and `year` columns plus the model variables.
        dep_var: Name of the dependent-variable column.
        regressors: Candidate regressor names; those absent from `df` are
            silently ignored.
        label: Human-readable label carried through to the result.

    Returns:
        A dict with `dep_var`, `label`, `n_obs`, `n_countries`, `r_squared`
        and, for each regressor used, `coef_<name>`, `se_<name>` and
        `p_<name>`. Returns None when `dep_var` is not a column, fewer than
        50 complete rows remain, or the dependent variable is (numerically)
        constant.
    """
    if dep_var not in df.columns:
        return None

    usable = [name for name in regressors if name in df.columns]

    # Keep only rows where the dependent variable and all regressors exist.
    sample = df.dropna(subset=[dep_var, *usable]).copy()
    if len(sample) < 50:
        return None

    # z-score the dependent variable; a zero spread cannot be scaled.
    raw = sample[dep_var].values
    centre, spread = raw.mean(), raw.std()
    if spread < 1e-10:
        return None
    scaled = (raw - centre) / spread

    model = PanelGLS()
    model.fit(scaled, sample[usable].values,
              sample['iso3'].values, sample['year'].values)

    summary = {
        'dep_var': dep_var,
        'label': label,
        'n_obs': model.n_obs,
        'n_countries': model.n_countries,
        'r_squared': model.r_squared,
    }
    # Estimates are read positionally, in the order the regressors were passed.
    for pos, name in enumerate(usable):
        summary.update({
            f'coef_{name}': model.beta[pos],
            f'se_{name}': model.se[pos],
            f'p_{name}': model.pvalues[pos],
        })
    return summary


def main():
    """Run Phase 6: fit every model family and write each summary table.

    Reads `asset_panel.csv` from PROCESSED_DIR, then in order produces the
    Kopecky-Taylor summary, the cross-asset matrix, the KAOPEN moderation
    table (only when interaction columns exist), the age decomposition and
    the projections. Progress is printed to stdout.
    """
    rule = "=" * 70
    print(rule)
    print("PHASE 6: Cross-Asset Synthesis")
    print(rule)

    panel = pd.read_csv(PROCESSED_DIR / "asset_panel.csv")
    print(f"Panel: {panel['iso3'].nunique()} countries, {len(panel)} obs")

    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    candidate_controls = ('rgdp_growth', 'inflation', 'fiscal_bal_gdp',
                          'kaopen', 'nfa_gdp_lag')
    controls = [name for name in candidate_controls if name in panel.columns]
    regressors = demo_vars + controls

    # (column name, display label) for every asset-class outcome.
    asset_specs = [
        ('real_bond_10y', 'Safe rate (10y)'),
        ('real_short_3m', 'Safe rate (3m)'),
        ('term_spread', 'Term spread'),
        ('log_reer', 'REER (log)'),
        ('d_reer', 'REER (Δ%)'),
        ('d_rhpi', 'House prices (Δ%)'),
        ('stock_market_cap_gdp', 'Stock mkt cap/GDP'),
        ('port_eq_assets_gdp', 'Portfolio equity/GDP'),
        ('carry_vs_usa', 'Carry vs USA'),
    ]

    def fit_each(frame, regs, prefix=""):
        """Fit every asset spec on `frame`; keep only the fits that succeed."""
        fits = []
        for dep_var, label in asset_specs:
            fit = run_standardized(frame, dep_var, regs, prefix + label)
            if fit:
                fits.append(fit)
        return fits

    # ── 1. Cross-asset coefficient table (standardized Y) ──
    print("\n1. Cross-asset standardized coefficient table ...")
    std_results = []
    for dep_var, label in asset_specs:
        fit = run_standardized(panel, dep_var, regressors, label)
        if not fit:
            continue
        std_results.append(fit)
        z1 = fit.get('coef_Z_1', np.nan)
        print(f"  {label:<25} R²={fit['r_squared']:.3f}  Z₁={z1:.3f}")

    # Same specifications on the OECD subsample, tagged with an "OECD: " prefix.
    oecd = panel.loc[panel['iso3'].isin(OECD_38)].copy()
    std_results.extend(fit_each(oecd, regressors, "OECD: "))

    # ── 2. Kopecky-Taylor summary ──
    print("\n2. Kopecky-Taylor murder-suicide summary ...")
    build_kopecky_taylor_table(std_results)

    # ── 3. Cross-asset coefficient table (rows=Z, cols=assets) ──
    print("\n3. Cross-asset coefficient matrix ...")
    build_cross_asset_matrix(std_results, asset_specs)

    # ── 4. KAOPEN moderation across all assets ──
    print("\n4. KAOPEN moderation table ...")
    int_vars = [name
                for name in ('Z_1_x_kaopen', 'Z_2_x_kaopen', 'Z_3_x_kaopen')
                if name in panel.columns]
    if int_vars:
        kaopen_results = fit_each(panel, regressors + int_vars)
        build_kaopen_table(kaopen_results, asset_specs, int_vars)

    # ── 5. Age decomposition (old_dep + youth_dep) ──
    print("\n5. Age decomposition ...")
    age_regs = ['old_dep', 'youth_dep'] + controls
    age_results = fit_each(panel, age_regs)
    build_age_decomp_table(age_results, asset_specs)

    # ── 6. Forward projections ──
    print("\n6. Forward projections ...")
    build_projections(panel, regressors, asset_specs)

    print("\n" + rule)
    print("Phase 6 complete.")
    print(rule)


def build_kopecky_taylor_table(std_results):
    """Write the Kopecky-Taylor summary table to `kopecky_taylor.md`.

    One row per full-sample regression (results whose label starts with
    'OECD' are excluded) showing the standardized Z₁ and Z₂ coefficients
    with significance stars, their p-values, N and R².

    Args:
        std_results: Result dicts as produced by `run_standardized`.

    Side effects:
        Creates TABLES_DIR if needed, writes the Markdown file there as
        UTF-8, and prints the saved path.
    """
    md = ["# Kopecky-Taylor Murder-Suicide Summary\n"]
    md.append("The 'murder-suicide of the rentier' predicts aging depresses safe "
              "returns (large negative Z₁) while sustaining/increasing the equity "
              "risk premium (small/null Z₁ on equity).\n")
    md.append("| Asset Class | Z₁ (std) | p-value | Z₂ (std) | p-value | N | R² |")
    md.append("|---|---|---|---|---|---|---|")

    for r in std_results:
        # Full-sample results only.
        if r['label'].startswith('OECD'):
            continue
        z1 = r.get('coef_Z_1', np.nan)
        if np.isnan(z1):
            # Without a Z₁ estimate the row carries no information.
            continue
        z1_p = r.get('p_Z_1', np.nan)
        z2 = r.get('coef_Z_2', np.nan)
        z2_p = r.get('p_Z_2', np.nan)
        # Show a dash rather than the literal "nan" when Z₂ was not estimated.
        if np.isnan(z2):
            z2_cell, z2_p_cell = "—", "—"
        else:
            z2_cell, z2_p_cell = f"{z2:.3f}{stars(z2_p)}", f"{z2_p:.4f}"
        md.append(
            f"| {r['label']} | {z1:.3f}{stars(z1_p)} | {z1_p:.4f} "
            f"| {z2_cell} | {z2_p_cell} "
            f"| {r['n_obs']} | {r['r_squared']:.3f} |"
        )

    md.append("\n*Dependent variables standardized (mean=0, sd=1). "
              "PanelGLS with AR(1). Controls: GDP growth, inflation, fiscal balance, "
              "KAOPEN, NFA/GDP lag.*")

    out = TABLES_DIR / "kopecky_taylor.md"
    out.parent.mkdir(parents=True, exist_ok=True)
    # The table contains non-ASCII characters (Z₁, R², —): pin the encoding
    # instead of relying on the platform default.
    out.write_text('\n'.join(md), encoding='utf-8')
    print(f"  Saved: {out}")


def build_cross_asset_matrix(std_results, asset_specs):
    """Write the coefficient matrix (rows = variables, columns = assets).

    Each column is one full-sample regression (labels starting with 'OECD'
    are excluded). Rows are drawn from Z_1, Z_2, Z_3, old_dep and youth_dep;
    a variable that has no estimate in any column is left out so the table
    does not carry all-dash rows. R² and N rows follow.

    Args:
        std_results: Result dicts as produced by `run_standardized`.
        asset_specs: Sequence of (dep_var, label) pairs fixing column order.

    Side effects:
        Creates TABLES_DIR if needed, writes `cross_asset_matrix.md` there as
        UTF-8, and prints the saved path.
    """
    full_results = {r['dep_var']: r for r in std_results
                    if not r['label'].startswith('OECD')}
    # Only assets that actually produced a result get a column.
    shown = [(dv, label) for dv, label in asset_specs if dv in full_results]

    md = ["# Cross-Asset Demographic Coefficient Matrix\n"]
    md.append("Standardized coefficients from PanelGLS. Each column is a separate regression.\n")

    # Header
    md.append("| Variable | " + " | ".join(label for _, label in shown) + " |")
    md.append("|---" + "|---" * len(shown) + "|")

    for zvar in ['Z_1', 'Z_2', 'Z_3', 'old_dep', 'youth_dep']:
        cells = []
        has_estimate = False
        for dv, _ in shown:
            r = full_results[dv]
            coef = r.get(f'coef_{zvar}')
            p = r.get(f'p_{zvar}')
            if coef is None or np.isnan(coef):
                cells.append("—")
                continue
            has_estimate = True
            # A coefficient without a stored p-value gets no stars.
            marker = stars(p) if p is not None else ''
            cells.append(f"{coef:.3f}{marker}")
        if not has_estimate:
            # Variable was not part of any of these regressions.
            continue
        md.append("| " + " | ".join([zvar] + cells) + " |")

    # R² row
    r2_cells = [f"{full_results[dv]['r_squared']:.3f}" for dv, _ in shown]
    md.append("| " + " | ".join(['R²'] + r2_cells) + " |")

    # N row
    n_cells = [f"{full_results[dv]['n_obs']}" for dv, _ in shown]
    md.append("| " + " | ".join(['N'] + n_cells) + " |")

    md.append("\n*\\*p<0.10, \\*\\*p<0.05, \\*\\*\\*p<0.01*")

    out = TABLES_DIR / "cross_asset_matrix.md"
    out.parent.mkdir(parents=True, exist_ok=True)
    # Non-ASCII content (R², —): pin the encoding explicitly.
    out.write_text('\n'.join(md), encoding='utf-8')
    print(f"  Saved: {out}")


def build_kaopen_table(kaopen_results, asset_specs, int_vars):
    """Write the KAOPEN interaction coefficients for every asset class.

    Args:
        kaopen_results: Result dicts as produced by `run_standardized` for
            the models that include the interaction terms.
        asset_specs: Sequence of (dep_var, label) pairs fixing row order.
        int_vars: Names of the interaction regressors that were actually
            included. The table gets one coefficient/p-value column pair per
            entry, in this order.

    Side effects:
        Creates TABLES_DIR if needed, writes `kaopen_moderation.md` there as
        UTF-8, and prints the saved path.
    """
    results_map = {r['dep_var']: r for r in kaopen_results}

    # Display names for the known interaction terms; anything else is shown
    # under its raw column name.
    pretty = {
        'Z_1_x_kaopen': 'Z₁×KAOPEN',
        'Z_2_x_kaopen': 'Z₂×KAOPEN',
        'Z_3_x_kaopen': 'Z₃×KAOPEN',
    }

    # Build the header from int_vars so the columns always line up with the
    # cells below, even when only some interaction terms exist in the data.
    header = ["Asset Class"]
    for iv in int_vars:
        header.extend([pretty.get(iv, iv), "p"])

    md = ["# KAOPEN Moderation Across Asset Classes\n"]
    md.append("| " + " | ".join(header) + " |")
    md.append("|---" * len(header) + "|")

    for dv, label in asset_specs:
        if dv not in results_map:
            continue
        r = results_map[dv]
        cells = [label]
        for iv in int_vars:
            coef = r.get(f'coef_{iv}', np.nan)
            p = r.get(f'p_{iv}', np.nan)
            if np.isnan(coef):
                cells.extend(["—", "—"])
            else:
                cells.extend([f"{coef:.3f}{stars(p)}", f"{p:.3f}"])
        md.append("| " + " | ".join(cells) + " |")

    out = TABLES_DIR / "kaopen_moderation.md"
    out.parent.mkdir(parents=True, exist_ok=True)
    # Non-ASCII content (Z₁×KAOPEN, —): pin the encoding explicitly.
    out.write_text('\n'.join(md), encoding='utf-8')
    print(f"  Saved: {out}")


def build_age_decomp_table(age_results, asset_specs):
    """Write the old_dep vs youth_dep coefficient table across assets.

    Args:
        age_results: Result dicts as produced by `run_standardized` for the
            age-decomposition models.
        asset_specs: Sequence of (dep_var, label) pairs fixing row order.

    Side effects:
        Creates TABLES_DIR if needed, writes `age_decomposition.md` there as
        UTF-8, and prints the saved path.
    """
    results_map = {r['dep_var']: r for r in age_results}

    def cell_pair(coef, p):
        """Return (coefficient, p-value) cells, or dashes if not estimated."""
        if np.isnan(coef):
            return "—", "—"
        return f"{coef:.3f}{stars(p)}", f"{p:.3f}"

    md = ["# Age Decomposition: Old vs Young Dependency\n"]
    md.append("| Asset Class | old_dep | p | youth_dep | p | N | R² |")
    md.append("|---|---|---|---|---|---|---|")

    for dv, label in asset_specs:
        if dv not in results_map:
            continue
        r = results_map[dv]
        old_c = r.get('coef_old_dep', np.nan)
        yth_c = r.get('coef_youth_dep', np.nan)
        # Skip only when neither dependency ratio was estimated; a row with
        # just one of the two is still worth reporting.
        if np.isnan(old_c) and np.isnan(yth_c):
            continue
        old_cell, old_p_cell = cell_pair(old_c, r.get('p_old_dep', np.nan))
        yth_cell, yth_p_cell = cell_pair(yth_c, r.get('p_youth_dep', np.nan))
        md.append(
            f"| {label} | {old_cell} | {old_p_cell} "
            f"| {yth_cell} | {yth_p_cell} "
            f"| {r['n_obs']} | {r['r_squared']:.3f} |"
        )

    out = TABLES_DIR / "age_decomposition.md"
    out.parent.mkdir(parents=True, exist_ok=True)
    # Non-ASCII content (R², —): pin the encoding explicitly.
    out.write_text('\n'.join(md), encoding='utf-8')
    print(f"  Saved: {out}")


def build_projections(df, regressors, asset_specs):
    """Compute the demographic component of each asset for focus countries.

    For every asset a PanelGLS model is fitted on the raw (unstandardized)
    dependent variable. The estimated Z coefficients are then applied to each
    focus country's most recent complete set of Z values, giving
    Z₁β₁ + Z₂β₂ + Z₃β₃.

    Args:
        df: Panel with `iso3`, `year`, the Z variables and the asset columns.
        regressors: Regressor names; those absent from `df` are ignored.
        asset_specs: Sequence of (dep_var, label) pairs. Assets whose column
            is missing, or with fewer than 50 complete rows, are skipped.

    Side effects:
        Creates TABLES_DIR if needed and writes `projections.md` (UTF-8,
        countries in FOCUS_COUNTRIES order, assets in `asset_specs` order)
        and `projections.csv` (long format). Prints progress messages.
    """
    demo_vars = ['Z_1', 'Z_2', 'Z_3']

    # Take each country's latest row in which every available Z is present.
    # groupby().last() would instead pick the last non-null value column by
    # column, which can combine Z values from different years and report a
    # `latest_year` that none of them belongs to.
    present_demo = [zv for zv in demo_vars if zv in df.columns]
    complete = df.dropna(subset=present_demo) if present_demo else df
    latest = complete.sort_values('year').groupby('iso3').tail(1)
    focus = latest[latest['iso3'].isin(FOCUS_COUNTRIES)]

    if focus.empty:
        print("  No focus countries found — skipping projections")
        return

    # Same for every asset, so computed once.
    avail_regs = [r for r in regressors if r in df.columns]

    # Estimate models and compute demographic contribution
    proj_rows = []
    for dep_var, label in asset_specs:
        # A missing asset column is skipped (as run_standardized does)
        # instead of raising KeyError inside dropna.
        if dep_var not in df.columns:
            continue
        sub = df.dropna(subset=[dep_var] + avail_regs)
        if len(sub) < 50:
            continue

        gls = PanelGLS()
        gls.fit(sub[dep_var].values, sub[avail_regs].values,
                sub['iso3'].values, sub['year'].values)

        # Coefficients are positional, in the order of avail_regs.
        z_coefs = {name: gls.beta[i]
                   for i, name in enumerate(avail_regs)
                   if name in demo_vars}

        # Demographic component = Z₁·β₁ + Z₂·β₂ + Z₃·β₃
        for _, row in focus.iterrows():
            demo_effect = sum(z_coefs.get(zv, 0) * row.get(zv, 0)
                              for zv in demo_vars)
            proj_rows.append({
                'iso3': row['iso3'],
                'asset': label,
                'dep_var': dep_var,
                'demo_effect': demo_effect,
                'latest_year': row['year'],
            })

    proj_df = pd.DataFrame(proj_rows)
    if proj_df.empty:
        return

    # Pivot: rows=countries, cols=assets
    pivot = proj_df.pivot(index='iso3', columns='asset', values='demo_effect')
    pivot = pivot.reindex(FOCUS_COUNTRIES).dropna(how='all')

    md = ["# Forward Projections: Demographic Pressure on Asset Classes\n"]
    md.append("Demographic component (Z₁β₁ + Z₂β₂ + Z₃β₃) using latest demographics.\n")

    # pivot() sorts its columns alphabetically; restore asset_specs order.
    cols = [label for _, label in asset_specs if label in pivot.columns]
    md.append("| Country | " + " | ".join(cols) + " |")
    md.append("|---" + "|---" * len(cols) + "|")
    for iso3 in pivot.index:
        cells = [iso3]
        for col in cols:
            val = pivot.loc[iso3, col]
            cells.append(f"{val:.3f}" if pd.notna(val) else "—")
        md.append("| " + " | ".join(cells) + " |")

    md.append("\n*Positive = demographic pressure pushing asset value up; "
              "negative = downward pressure.*")

    TABLES_DIR.mkdir(parents=True, exist_ok=True)
    out = TABLES_DIR / "projections.md"
    # Non-ASCII content (Z₁β₁, —): pin the encoding explicitly.
    out.write_text('\n'.join(md), encoding='utf-8')
    print(f"  Saved: {out}")

    # Also save CSV
    proj_df.to_csv(TABLES_DIR / "projections.csv", index=False)


if __name__ == "__main__":
    main()
